[tvm-ffi] TVMFFIBuilder #111
Conversation
Note
Other AI code review bot(s) detected: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
A TVM-FFI based CUDA kernel builder (TVMFFIBuilder) is added and registered ahead of the pybind11 CUDA builder. Supporting changes include a TVMFFIRunnable destination-passing execution model, Definition API renames (const_axes/var_axes), a new is_cuda_available() utility, a pyproject dependency on apache-tvm-ffi, tests for TVMFFIBuilder, and ancillary CUDA builder and package-name adjustments.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Client
participant Builder as TVMFFIBuilder
participant FS as FileSystem/Cache
participant TVM as TVM-FFI
participant Runnable as TVMFFIRunnable
Client->>Builder: build(solution)
activate Builder
Builder->>Builder: can_build(solution)?
Builder->>Builder: _make_key / _get_build_path
Builder->>FS: check cache & write sources
Builder->>Builder: _get_language / _get_entry_symbol
Builder->>TVM: tvm_ffi.cpp.build(...) -- compile
Builder->>TVM: tvm_ffi.load_module(...) -- load
Builder->>Builder: _make_runnable(module, solution)
Builder-->>Client: TVMFFIRunnable(fn, meta)
deactivate Builder
Client->>Runnable: __call__(inputs)
activate Runnable
Runnable->>Runnable: get_output_shapes(inputs)
Runnable->>Runnable: allocate output_tensors
Runnable->>Runnable: call_dest(inputs, outputs)
Runnable-->>Client: output_tensor(s)
deactivate Runnable
```
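To make the diagram concrete, here is a minimal Python sketch of the destination-passing flow; it is not the real TVMFFIRunnable, and the helper `dtype_of` plus the exact `get_var_values`/`get_output_shapes` signatures are assumptions about the Definition API:

```python
import torch

class DestPassingRunnable:
    """Minimal sketch of the flow above; not the actual TVMFFIRunnable."""

    def __init__(self, fn, defn):
        self._fn = fn      # compiled entry point, called as fn(*inputs, *outputs)
        self._defn = defn  # kernel Definition with input/output specs

    def __call__(self, **inputs):
        # 1. Infer variable axis values from the input tensor shapes.
        var_values = self._defn.get_var_values({k: v.shape for k, v in inputs.items()})
        # 2. Size and allocate the outputs on the first input's device.
        device = next(iter(inputs.values())).device
        outputs = {
            name: torch.empty(shape, dtype=self._defn.dtype_of(name), device=device)
            for name, shape in self._defn.get_output_shapes(var_values).items()
        }
        # 3. Destination-passing: the kernel writes into the pre-allocated outputs.
        self._fn(*inputs.values(), *outputs.values())
        outs = list(outputs.values())
        return outs[0] if len(outs) == 1 else outs
```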
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates TVM-FFI into the kernel build pipeline.
Code Review
This PR introduces TVM-FFI integration through TVMFFIBuilder and PrebuiltLibraryManager, which is a great addition for kernel compilation and caching. The implementation is solid, with good test coverage and a helpful example. I've found a critical bug in argument handling for the compiled kernels and a high-severity issue in the prebuilt library path logic that could lead to incorrect behavior. I've also included a few medium-severity suggestions to improve an example's efficiency, enhance error handling, and add a test case for the path logic bug. Overall, great work on this feature.
```python
raise BuildError(f"Symbol '{symbol}' not found in module") from e
...
# Create keyword adapter to match definition interface
arg_order = list(defn.inputs.keys())
```
The arg_order is constructed using only defn.inputs. This will lead to incorrect arguments being passed to the native function, as it omits outputs and other scalar arguments (axes). Based on the example provided, the arguments should include inputs, outputs, and axes values.
```diff
-arg_order = list(defn.inputs.keys())
+arg_order = list(defn.inputs.keys()) + list(defn.outputs.keys()) + list(defn.axes.keys())
```
flashinfer_bench/compile/prebuilt.py (outdated)
```python
def get_cache_dir(self) -> str:
    """Get the cache directory for compiling new libraries.

    Returns the last path in search_paths, which is always the local cache.
    """
    cache_dir = self._search_paths[-1]  # Always the local cache
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
```
This method incorrectly assumes the cache directory is always the last path in self._search_paths. This assumption fails if extra_paths are provided and the cache directory is also specified with higher priority (e.g., via FIB_PREBUILT_PATH). This can lead to compiled libraries being written to the wrong directory.
To fix this, the cache directory should be explicitly managed instead of being inferred from the search path order. A more robust approach (see the sketch after this list) would be:

- Store the cache directory path in an instance variable in `__init__`.
- `_build_search_paths` should ensure this cache path is included in the search paths.
- `get_cache_dir` should then simply return the stored instance variable and ensure it exists.
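A minimal sketch of that approach, assuming the existing `get_fib_cache_path()` helper; the search-path construction (including how FIB_PREBUILT_PATH is parsed) is simplified and illustrative:

```python
import os
from flashinfer_bench.env import get_fib_cache_path  # existing helper

class PrebuiltLibraryManager:
    """Sketch only: search-path construction is simplified."""

    def __init__(self, extra_paths=None):
        # Record the cache directory explicitly instead of inferring it
        # from search-path order.
        self._cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
        env_paths = [p for p in os.environ.get("FIB_PREBUILT_PATH", "").split(os.pathsep) if p]
        self._search_paths = env_paths + list(extra_paths or [])
        if self._cache_dir not in self._search_paths:
            self._search_paths.append(self._cache_dir)

    def get_cache_dir(self) -> str:
        # Always the explicitly stored cache path, regardless of overrides.
        os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir
```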
examples/tvm_ffi_example.py (outdated)
```python
a_jax = jnp.array(a_torch.cpu().numpy())
b_jax = jnp.array(b_torch.cpu().numpy())
```
The current method of creating JAX tensors from PyTorch tensors is inefficient as it involves a gpu -> cpu -> gpu roundtrip via .cpu().numpy(). For this example, it's clearer and more efficient to generate random JAX tensors directly on the device, similar to the CuPy example. This makes the test case for each framework independent and avoids the performance overhead of data transfer.
Note: You'll need to add import jax at the beginning of the try block for this suggestion to work.
```diff
-a_jax = jnp.array(a_torch.cpu().numpy())
-b_jax = jnp.array(b_torch.cpu().numpy())
+key = jax.random.PRNGKey(0)
+a_jax, b_jax = jax.random.normal(key, (2, n), dtype=jnp.float32)
```
```python
except Exception:
    pass
```
Catching a broad Exception and silently passing with pass can hide important errors. If an error occurs here, dependency resolution might fail silently, leading to compilation errors later that are hard to debug. It would be better to catch more specific exceptions (e.g., ModuleNotFoundError if a package for resources.files doesn't exist) or at least log a warning that a dependency path could not be resolved.
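As a sketch, narrower handling with a warning might look like this; the function name `_find_package_path` is hypothetical, and only the exception choice and logging are the point:

```python
import logging
from importlib import resources

logger = logging.getLogger(__name__)

def _find_package_path(pkg_name: str):
    """Hypothetical helper: resolve a package's resource root, or None."""
    try:
        return resources.files(pkg_name)
    except ModuleNotFoundError as e:
        # Warn instead of silently passing, so a missing dependency is
        # visible before it surfaces as an opaque compile error.
        logger.warning("Could not resolve package path for %s: %s", pkg_name, e)
        return None
```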
tests/compile/test_prebuilt.py (outdated)
```python
def test_get_cache_dir(self):
    """Test getting cache directory."""
    manager = PrebuiltLibraryManager()
    cache_dir = manager.get_cache_dir()

    # Should exist and be writable
    assert os.path.exists(cache_dir)
    assert os.path.isdir(cache_dir)
    assert os.access(cache_dir, os.W_OK)
```
This test for get_cache_dir is good, but it doesn't cover a potential bug where an incorrect cache directory could be returned. This can happen if the cache path is also specified as a higher-priority path (e.g., via FIB_PREBUILT_PATH) and extra_paths are also provided.
Please consider adding a test case to cover this scenario, which would fail with the current implementation and pass after fixing the issue in PrebuiltLibraryManager. Here is a suggested test:
```python
def test_get_cache_dir_priority(self, monkeypatch):
    """Test get_cache_dir returns correct path even with higher priority overrides."""
    from flashinfer_bench.env import get_fib_cache_path
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        extra_dir = os.path.join(tmpdir, "extra")
        os.makedirs(extra_dir)
        cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
        monkeypatch.setenv("FIB_PREBUILT_PATH", cache_dir)
        manager = PrebuiltLibraryManager(extra_paths=[extra_dir])
        assert manager.get_cache_dir() == cache_dir
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
# Create keyword adapter to match definition interface
arg_order = list(defn.inputs.keys())

def _kw_adapter(**kwargs):
    args = [kwargs[name] for name in arg_order]
    return fn(*args)
```
Drop output and scalar arguments when invoking TVM FFI kernels
The adapter returned from _make_runnable only forwards values for defn.inputs (arg_order = list(defn.inputs.keys())). Keywords such as outputs or axis/scalar parameters are ignored, so calls like runnable(a=a_torch, b=b_torch, c=c_torch, n=n) are translated to vector_add(a, b) and raise a TypeError or write to uninitialised memory. Kernel definitions always contain outputs and usually dimension arguments, so the builder cannot execute compiled libraries. The adapter should pass through all required kwargs (inputs, outputs, scalars) or match the entry-point signature explicitly.
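A sketch of an adapter that forwards everything the definition declares; the inputs → outputs → axes ordering is an assumption about the entry-point convention (it matches the vector_add(a, b, c, n) example discussed in this thread):

```python
def make_kw_adapter(fn, defn):
    """Sketch: forward inputs, then outputs, then scalar axis values."""
    arg_order = (
        list(defn.inputs.keys())
        + list(defn.outputs.keys())
        + list(defn.axes.keys())
    )

    def _kw_adapter(**kwargs):
        missing = [name for name in arg_order if name not in kwargs]
        if missing:
            # Fail loudly instead of silently dropping arguments.
            raise TypeError(f"missing arguments for {arg_order}: {missing}")
        return fn(*(kwargs[name] for name in arg_order))

    return _kw_adapter
```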
Actionable comments posted: 3
🧹 Nitpick comments (13)
flashinfer_bench/compile/registry.py (1)
55-64: Builder ordering LGTM; add a tiny trace to see which backend was chosen at runtime. Priority Python > Triton > TVM-FFI > CUDA is sensible. Consider logging the chosen builder once in BuilderRegistry.build for observability during benchmarks.
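For illustration, such a trace could look like this; a sketch only, with the registry's real constructor and error handling simplified:

```python
import logging

from flashinfer_bench.compile.builder import BuildError  # existing class

logger = logging.getLogger(__name__)

class BuilderRegistry:
    def __init__(self, builders):
        # Ordered by priority: Python > Triton > TVM-FFI > CUDA.
        self._builders = builders

    def build(self, defn, sol):
        for builder in self._builders:
            if builder.can_build(sol):
                # One-line trace so benchmark logs show which backend ran.
                logger.debug("Building solution with %s", type(builder).__name__)
                return builder.build(defn, sol)
        raise BuildError("No registered builder can build this solution")
```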
flashinfer_bench/compile/builders/__init__.py (1)
4-6: Sort `__all__` to satisfy Ruff RUF022. Apply isort-style sorting. For strict lint compliance, change to alphabetical order:

```diff
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
+__all__ = ["CUDABuilder", "PythonBuilder", "TVMFFIBuilder", "TritonBuilder"]
```

As per static analysis hints.
flashinfer_bench/compile/__init__.py (1)
7-7: New exports are fine; keep `__all__` sorted to appease Ruff. Sort entries alphabetically to pass RUF022.

```diff
-__all__ = [
-    "Builder",
-    "BuildError",
-    "BuilderRegistry",
-    "Runnable",
-    "get_builder_registry",
-    "PrebuiltLibraryManager",
-    "get_prebuilt_manager",
-]
+__all__ = [
+    "Builder",
+    "BuilderRegistry",
+    "BuildError",
+    "PrebuiltLibraryManager",
+    "Runnable",
+    "get_builder_registry",
+    "get_prebuilt_manager",
+]
```

examples/tvm_ffi_example.py (2)
84-101: JAX section likely runs on CPU arrays; ensure GPU device or skip. For a CUDA kernel, place arrays on a GPU device before calling, or skip if no GPU is present to avoid runtime errors.

```diff
-    import jax.numpy as jnp
+    import jax, jax.numpy as jnp
@@
-    a_jax = jnp.array(a_torch.cpu().numpy())
-    b_jax = jnp.array(b_torch.cpu().numpy())
-    c_jax = jnp.empty((n,), dtype=jnp.float32)
+    gpus = jax.devices("gpu")
+    if not gpus:
+        raise ImportError("No JAX GPU device")
+    dev = gpus[0]
+    a_jax = jax.device_put(jnp.array(a_torch.cpu().numpy()), dev)
+    b_jax = jax.device_put(jnp.array(b_torch.cpu().numpy()), dev)
+    c_jax = jax.device_put(jnp.empty((n,), dtype=jnp.float32), dev)
```
9-13: Remove unused import. `flashinfer_bench as fib` is unused.

```diff
-import flashinfer_bench as fib
 from flashinfer_bench.compile import get_builder_registry
```

flashinfer_bench/compile/prebuilt.py (3)
15-21: Docstring search order outdated. The implementation also supports extra_paths and appends the local cache as step 4. Update the docstring to reflect the actual order.
103-110: Edge case: ensure get_cache_dir always returns the actual cache path even if it appeared earlier. If an identical cache path is already in search_paths (e.g., via env/extra_paths), the "last element" assumption may break. Compute and return the canonical cache path directly.

```diff
-    cache_dir = self._search_paths[-1]  # Always the local cache
-    os.makedirs(cache_dir, exist_ok=True)
-    return cache_dir
+    cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
+    os.makedirs(cache_dir, exist_ok=True)
+    return cache_dir
```
88-101: Cross-platform: only searches for .so. Support .dylib (macOS) and .dll (Windows) to make prebuilt discovery portable.

```diff
-    # Try both with and without .so extension
-    if not lib_name.endswith(".so"):
-        filename = f"{lib_name}.so"
-    else:
-        filename = lib_name
-
-    for search_path in self._search_paths:
-        lib_path = os.path.join(search_path, filename)
-        if os.path.exists(lib_path):
-            logger.debug(f"Found prebuilt library: {lib_path}")
-            return lib_path
+    # Try common extensions across platforms
+    candidates = [lib_name] if any(lib_name.endswith(ext) for ext in (".so", ".dylib", ".dll")) \
+        else [f"{lib_name}{ext}" for ext in (".so", ".dylib", ".dll")]
+    for search_path in self._search_paths:
+        for fname in candidates:
+            lib_path = os.path.join(search_path, fname)
+            if os.path.exists(lib_path):
+                logger.debug(f"Found prebuilt library: {lib_path}")
+                return lib_path
```

tests/compile/test_prebuilt.py (1)
79-93: Make tests portable across OS library suffixes. Hardcoding .so breaks on Windows/macOS. Parameterize the extension or probe via multiple suffixes.

```diff
-    lib_path = os.path.join(tmpdir, "test_lib.so")
+    import sys
+    ext = ".dll" if sys.platform.startswith("win") else (".dylib" if sys.platform == "darwin" else ".so")
+    lib_path = os.path.join(tmpdir, f"test_lib{ext}")
@@
-    found2 = manager.find("test_lib.so")
+    found2 = manager.find(f"test_lib{ext}")
```

tests/compile/test_tvm_ffi_builder.py (2)
178-179: Escape regex metacharacters in match=. Use a raw string or escape dots so the message matches literally.

```diff
-with pytest.raises(BuildError, match="No .cu CUDA sources"):
+with pytest.raises(BuildError, match=r"No \.cu CUDA sources"):
```

As per static analysis hint (RUF043).
114-122: Avoid relying on private method in test (builder._get_lib_path). Accessing a private method couples tests to internals. Prefer asserting on the returned path from build (e.g., via runnable.meta["cache_dir"] and expected filename) or expose a small public helper.

```diff
-    lib_path = builder._get_lib_path(simple_solution)
+    cache_dir = builder._prebuilt_manager.get_cache_dir()
+    # Derive expected filename from solution name if needed, or assert any *.so exists under cache_dir.
+    lib_path = next((p for p in (os.path.join(cache_dir, f) for f in os.listdir(cache_dir)) if p.endswith(".so")), None)
```

flashinfer_bench/compile/builders/tvm_ffi_builder.py (2)
49-64: Use unpacking for list concatenation. On line 64, prefer unpacking over concatenation for better performance and readability.
Apply this diff:
```diff
 elif sys.platform == "win32":
-    ldflags = [f"/LIBPATH:{lib_path}"] + lib_names
+    ldflags = [f"/LIBPATH:{lib_path}", *lib_names]
```
66-69: Consider logging exceptions during package path discovery. The bare `except Exception: pass` silently swallows all errors during package path discovery, making it difficult to diagnose issues when dependencies are misconfigured or packages are malformed.

Apply this diff:

```diff
 except Exception:
-    pass
+    logger.debug(f"Failed to discover package paths for {pkg_name}")
 return include_path, ldflags
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- examples/tvm_ffi_example.py (1 hunks)
- flashinfer_bench/compile/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1 hunks)
- flashinfer_bench/compile/prebuilt.py (1 hunks)
- flashinfer_bench/compile/registry.py (1 hunks)
- pyproject.toml (1 hunks)
- tests/compile/test_prebuilt.py (1 hunks)
- tests/compile/test_tvm_ffi_builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
flashinfer_bench/compile/registry.py (3)
- flashinfer_bench/compile/builders/cuda_builder.py (1): CUDABuilder (122-238)
- flashinfer_bench/compile/builders/python_builder.py (1): PythonBuilder (20-104)
- flashinfer_bench/compile/builders/triton_builder.py (1): TritonBuilder (18-51)

tests/compile/test_prebuilt.py (1)
- flashinfer_bench/compile/prebuilt.py (5): PrebuiltLibraryManager (14-110), get_prebuilt_manager (113-131), search_paths (71-73), find (75-101), get_cache_dir (103-110)

flashinfer_bench/compile/builders/__init__.py (1)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (123-287)

flashinfer_bench/compile/__init__.py (1)
- flashinfer_bench/compile/prebuilt.py (2): PrebuiltLibraryManager (14-110), get_prebuilt_manager (113-131)

examples/tvm_ffi_example.py (2)
- flashinfer_bench/compile/registry.py (2): get_builder_registry (52-64), build (26-31)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)

flashinfer_bench/compile/builders/tvm_ffi_builder.py (5)
- flashinfer_bench/compile/builder.py (3): Builder (52-95), BuildError (48-49), create_pkg_name (33-45)
- flashinfer_bench/compile/prebuilt.py (3): PrebuiltLibraryManager (14-110), get_cache_dir (103-110), find (75-101)
- flashinfer_bench/compile/runnable.py (1): Runnable (6-38)
- flashinfer_bench/data/solution.py (2): SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/logging.py (1): get_logger (9-12)

flashinfer_bench/compile/prebuilt.py (2)
- flashinfer_bench/env.py (1): get_fib_cache_path (46-57)
- flashinfer_bench/logging.py (1): get_logger (9-12)

tests/compile/test_tvm_ffi_builder.py (4)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (4): TVMFFIBuilder (123-287), _verify_tvm_ffi (23-31), can_build (155-156), _get_lib_path (186-189)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/compile/registry.py (1): build (26-31)
- flashinfer_bench/compile/builder.py (1): BuildError (48-49)
🪛 GitHub Actions: .github/workflows/linting.yaml
flashinfer_bench/compile/builders/tvm_ffi_builder.py
[error] 27-27: F401 'tvm_ffi.cpp' imported but unused. Remove unused import to satisfy linter.
🪛 Ruff (0.14.3)
flashinfer_bench/compile/builders/__init__.py
6-6: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer_bench/compile/__init__.py
11-19: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer_bench/compile/builders/tvm_ffi_builder.py
29-29: Consider moving this statement to an else block
(TRY300)
64-64: Consider [f"/LIBPATH:{lib_path}", *lib_names] instead of concatenation
Replace with [f"/LIBPATH:{lib_path}", *lib_names]
(RUF005)
66-67: try-except-pass detected, consider logging the exception
(S110)
66-66: Do not catch blind exception: Exception
(BLE001)
169-169: Avoid specifying long messages outside the exception class
(TRY003)
173-173: Avoid specifying long messages outside the exception class
(TRY003)
181-183: Avoid specifying long messages outside the exception class
(TRY003)
200-203: Avoid specifying long messages outside the exception class
(TRY003)
215-215: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Do not catch blind exception: Exception
(BLE001)
254-254: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
tests/compile/test_tvm_ffi_builder.py
178-178: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (9)
pyproject.toml (1)
26-32: Wheel availability confirmed for specified targets—review comment is correct. The verification shows wheels exist for your CI targets:
- Linux/x86_64: cp310, cp311, cp312-abi3 (covers 3.13 via stable ABI)
- Python 3.10–3.13: All versions covered via direct wheels (cp310, cp311) and abi3 wheels (cp312+)
The package name and import path are correct per PyPI documentation. No action needed.
flashinfer_bench/compile/builders/tvm_ffi_builder.py (8)
1-21: LGTM! The imports and constants are well-structured and appropriate for the TVM-FFI builder implementation.

72-86: LGTM! The dependency configuration and detection patterns are well-structured and comprehensive.

89-96: LGTM! The dependency discovery logic correctly populates include paths and linker flags.

99-120: LGTM! The dependency checking logic is thorough, with appropriate comment stripping to avoid false positives and an early optimization check.

123-156: LGTM! The class initialization and availability checks are well-designed, with appropriate caching of the TVM-FFI availability check and proper integration with the PrebuiltLibraryManager.

158-189: LGTM! The helper methods are well-structured with clear error handling. The error messages provide good context for debugging.

191-210: LGTM! The dependency collection logic correctly validates that required dependencies are available and provides clear error messages when they're missing.

212-287: LGTM! The build and runnable creation logic is well-designed with proper fallback handling. The bare exception catch on line 230 is appropriate here—if loading a prebuilt library fails for any reason, the builder correctly falls back to recompilation. The keyword-to-positional adapter ensures compatibility with the Definition interface.
examples/tvm_ffi_example.py (outdated)
```python
# 3. Build with TVM-FFI (compiles on first run, cached afterwards)
print("Building kernel with TVM-FFI...")
builder_registry = get_builder_registry()
runnable = builder_registry.build(definition, solution)
print(f"✓ Built successfully: {runnable.meta}")

# 4. Use in PyTorch (DLPack auto-conversion)
print("\n=== PyTorch Test ===")
n = 1000000
a_torch = torch.randn(n, device="cuda", dtype=torch.float32)
b_torch = torch.randn(n, device="cuda", dtype=torch.float32)
c_torch = torch.empty(n, device="cuda", dtype=torch.float32)

runnable(a=a_torch, b=b_torch, c=c_torch, n=n)

expected = a_torch + b_torch
torch.testing.assert_close(c_torch, expected, rtol=1e-5, atol=1e-5)
print("✓ PyTorch: Result correct")
```
Critical: runnable drops c and n; call will pass only a,b.
The adapter in TVMFFIBuilder (and CUDA builder) builds args from defn.inputs only, so kwargs c and n are ignored. Your kernel signature requires (a, b, c, n) and will fail. Fix the adapter to include outputs and axes in a defined order (inputs + outputs + axes), or change the example to match a two‑arg typed function.
```diff
-runnable(a=a_torch, b=b_torch, c=c_torch, n=n)
+# After fixing adapter to pass inputs+outputs+axes:
+runnable(a=a_torch, b=b_torch, c=c_torch, n=n)
```

Follow-up suggested change in builder (flashinfer_bench/compile/builders/tvm_ffi_builder.py):
```diff
-arg_order = list(defn.inputs.keys())
-def _kw_adapter(**kwargs):
-    args = [kwargs[name] for name in arg_order]
-    return fn(*args)
+# Pass inputs + outputs + named axes (if provided) to match C/CUDA signature.
+arg_order = [*defn.inputs.keys(), *defn.outputs.keys(), *sorted(getattr(defn, "axes", {}).keys())]
+def _kw_adapter(**kwargs):
+    args = [kwargs[name] for name in arg_order if name in kwargs]
+    return fn(*args)
```
+ return fn(*args)Also consider documenting the expected argument order for C/CUDA entry points.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/tvm_ffi_example.py around lines 65 to 83, the runnable call
currently passes a, b, c, n as kwargs but the TVMFFI/CUDA adapter builds
argument lists from definition.inputs only, so c and n are dropped and the
kernel signature (a, b, c, n) will fail; fix the adapter in
flashinfer_bench/compile/builders/tvm_ffi_builder.py (and the CUDA builder) to
build positional args in the canonical order inputs + outputs + axes (or inputs
+ outputs + scalar/axis params) and ensure the runnable wrapper accepts kwargs
and maps them into that ordered positional list before calling the C entry
point, and update the example or add docs to clearly state the required argument
ordering for C/CUDA entry points.
```python
def _verify_tvm_ffi() -> bool:
    """Check if tvm_ffi is available."""
    try:
        import tvm_ffi
        import tvm_ffi.cpp

        return True
    except ImportError:
        return False
```
Fix unused import flagged by linter.
The import tvm_ffi.cpp on line 27 is flagged as unused by the linter. Since this import is only for verification purposes and tvm_ffi.cpp is imported again in the _build method where it's actually used, you can remove it here.
Apply this diff:
```diff
 def _verify_tvm_ffi() -> bool:
     """Check if tvm_ffi is available."""
     try:
         import tvm_ffi
-        import tvm_ffi.cpp

         return True
     except ImportError:
         return False
```

📝 Committable suggestion
return False📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _verify_tvm_ffi() -> bool:
    """Check if tvm_ffi is available."""
    try:
        import tvm_ffi

        return True
    except ImportError:
        return False
```
🧰 Tools
🪛 GitHub Actions: .github/workflows/linting.yaml
[error] 27-27: F401 'tvm_ffi.cpp' imported but unused. Remove unused import to satisfy linter.
🪛 Ruff (0.14.3)
29-29: Consider moving this statement to an else block
(TRY300)
🤖 Prompt for AI Agents
In flashinfer_bench/compile/builders/tvm_ffi_builder.py around lines 23 to 31,
the helper _verify_tvm_ffi currently imports tvm_ffi.cpp but that import is
unused and flagged by the linter; modify the function to only import tvm_ffi
(not tvm_ffi.cpp) inside the try block and return True on success, leaving the
ImportError handling unchanged so the actual tvm_ffi.cpp import remains in the
_build method where it is used.
Force-push: bbe3a61 to f1dc96c
Actionable comments posted: 6
🧹 Nitpick comments (6)
flashinfer_bench/compile/builder.py (1)
33-45: Hash suffix length change is reasonable; note cache key impact. Bumping the SHA-1 suffix from 4 to 6 hex chars improves collision resistance for package names and should be harmless, but it will stop reusing any artifacts keyed with the old suffix length (they'll just be rebuilt and old cache entries remain orphaned on disk). If that cache is user-visible or disk-sensitive, consider a one-time cleanup or a short note in release docs.
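For context, a hedged sketch of the kind of suffixing being discussed; the real create_pkg_name lives in flashinfer_bench/compile/builder.py and its exact signature and payload are assumptions here:

```python
import hashlib

def create_pkg_name(base: str, payload: str) -> str:
    # Six hex chars of SHA-1 instead of four: fewer collisions, but any
    # artifact cached under the old 4-char suffix will simply be rebuilt.
    digest = hashlib.sha1(payload.encode()).hexdigest()[:6]
    return f"{base}_{digest}"
```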
flashinfer_bench/compile/runnable.py (3)
10-20: Runnable meta is now stored by reference instead of copied. Storing `meta` directly instead of copying it means any external mutations to the dict (or sharing the same dict across multiple runnables) will be visible here. If `meta` is treated as immutable metadata this is fine, but if callers may mutate it, consider documenting the expectation or reinstating a defensive copy (`dict(meta or {})`) to avoid surprising aliasing.

75-82: Unify close() behavior with base Runnable. `TVMFFIRunnable.close()` currently reimplements the close logic without the `try/finally` guard used in `Runnable.close()`. While functionally similar in the non-error case, using `super().close()` would keep the behavior consistent (e.g., ensuring `_closer` is nulled even if it raises):

```diff
-    def close(self) -> None:
-        if self._closer:
-            self._closer()
-            self._closer = None
+    def close(self) -> None:
+        super().close()
```

This also centralizes any future changes to close semantics in the base class.

41-73: Confirmed: torch.empty device optimization is valid for torch 2.8.0. The dest-passing logic is correct and well-structured:

- Uses `Definition.get_var_values` on input tensor shapes, then `get_output_shapes` to size outputs
- Allocates `torch.empty` with the appropriate dtype via `dtype_str_to_torch_dtype` and places outputs on the first input's device
- Returns a single tensor vs. a list based on output count, mirroring base `Runnable` behavior

Two clarifications worth documenting:

- All input kwargs are assumed to be torch tensors with `.shape` and `.device` attributes (non-tensor or mixed-device inputs would misbehave)
- Outputs are always allocated on the first input's device; if heterogeneous devices or explicit selection appears later, consider adding a device parameter or validation

PyTorch 2.8.0 supports torch.empty with a device parameter that allocates the tensor directly on the specified device, so optimize the allocation:

```diff
-            output_tensors[name] = torch.empty(
-                shape, dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype)
-            ).to(device)
+            output_tensors[name] = torch.empty(
+                shape,
+                dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype),
+                device=device,
+            )
```

flashinfer_bench/compile/builders/__init__.py (1)
4-6: LGTM: TVMFFIBuilder integration. The import and export are correct and properly integrate the new TVM-FFI builder.

Note: Static analysis suggests sorting `__all__` alphabetically for consistency. This is purely stylistic and can be deferred.

```diff
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
+__all__ = ["CUDABuilder", "PythonBuilder", "TVMFFIBuilder", "TritonBuilder"]
```

flashinfer_bench/compile/builders/tvm_ffi_builder.py (1)
74-76: Optional: Simplify unused parameter. The `cpp_files` parameter is unused in `_get_language`. While this maintains API symmetry, consider simplifying if the parameter won't be needed.

```diff
-    def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str:
+    def _get_language(self, cuda_files: List[str]) -> str:
         return "cuda" if len(cuda_files) > 0 else "cpp"
```

And update the call site at line 112.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- flashinfer_bench/compile/builder.py (1 hunks)
- flashinfer_bench/compile/builders/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/cuda_builder.py (7 hunks)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1 hunks)
- flashinfer_bench/compile/registry.py (1 hunks)
- flashinfer_bench/compile/runnable.py (2 hunks)
- flashinfer_bench/data/definition.py (5 hunks)
- flashinfer_bench/data/trace.py (1 hunks)
- flashinfer_bench/utils.py (1 hunks)
- pyproject.toml (1 hunks)
- tests/compile/test_tvm_ffi_builder.py (1 hunks)
- tests/data/test_definition.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- flashinfer_bench/data/trace.py
🧰 Additional context used
🧬 Code graph analysis (7)
flashinfer_bench/compile/registry.py (1)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (22-143)

flashinfer_bench/compile/runnable.py (2)
- flashinfer_bench/data/definition.py (2): get_var_values (239-279), get_output_shapes (343-363)
- flashinfer_bench/utils.py (1): dtype_str_to_torch_dtype (39-45)

flashinfer_bench/compile/builders/cuda_builder.py (3)
- flashinfer_bench/utils.py (1): is_cuda_available (57-59)
- flashinfer_bench/compile/builder.py (2): _build (64-66), BuildError (48-49)
- flashinfer_bench/compile/runnable.py (1): Runnable (9-38)

flashinfer_bench/compile/builders/tvm_ffi_builder.py (5)
- flashinfer_bench/compile/builder.py (8): Builder (52-95), BuildError (48-49), create_pkg_name (33-45), can_build (59-61), _make_key (74-76), _make_closer (69-71), _build (64-66), build (78-85)
- flashinfer_bench/compile/runnable.py (2): Runnable (9-38), TVMFFIRunnable (41-82)
- flashinfer_bench/data/solution.py (1): SupportedLanguages (12-24)
- flashinfer_bench/env.py (1): get_fib_cache_path (46-57)
- flashinfer_bench/compile/builders/cuda_builder.py (5): can_build (135-136), _make_key (138-139), _make_closer (141-144), _kw_adapter (217-219), _build (146-233)

tests/data/test_definition.py (1)
- flashinfer_bench/data/definition.py (2): const_axes (194-202), var_axes (205-213)

tests/compile/test_tvm_ffi_builder.py (4)
- flashinfer_bench/compile/builder.py (3): BuildError (48-49), build (78-85), can_build (59-61)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (2): TVMFFIBuilder (22-143), can_build (41-42)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/compile/runnable.py (1): call_dest (75-77)

flashinfer_bench/compile/builders/__init__.py (1)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (22-143)
🪛 Ruff (0.14.4)
flashinfer_bench/compile/builders/tvm_ffi_builder.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
71-71: Avoid specifying long messages outside the exception class
(TRY003)
74-74: Unused method argument: cpp_files
(ARG002)
81-83: Avoid specifying long messages outside the exception class
(TRY003)
93-93: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
131-131: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/data/definition.py
267-270: Avoid specifying long messages outside the exception class
(TRY003)
275-278: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/compile/builders/__init__.py
6-6: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🔇 Additional comments (13)
flashinfer_bench/utils.py (1)
57-60: CUDA availability helper follows PyTorch best practices. The `is_cuda_available()` wrapper correctly uses `torch.cuda.is_available()`, which remains the recommended way to check for CUDA in PyTorch 2.8. The implementation is clear and appropriate.

tests/data/test_definition.py (1)

43-44: LGTM: Test updated for new API. The test correctly uses the renamed `const_axes()` and `var_axes()` property accessors, aligning with the API changes in flashinfer_bench/data/definition.py.

flashinfer_bench/data/definition.py (4)

114-115: Good clarification on output naming constraints. The updated docstring explicitly states that output names must not overlap with input names, improving API clarity.

193-213: LGTM: Cleaner property-based API. Converting `get_const_axes` and `get_var_axes` to `cached_property` accessors improves the API ergonomics and follows Python conventions.

215-237: LGTM: Consistent API renaming. The `var_axes_bindings` property maintains consistency with the other API changes while preserving the cached_property behavior.

239-279: LGTM: Robust variable axis resolution. The `get_var_values` method correctly derives variable axis values from input shapes with proper consistency validation across inputs and completeness checks.

Note: Static analysis suggests extracting exception messages to exception classes (TRY003). This is a stylistic preference; the current detailed error messages are valuable for debugging and can remain as-is.

flashinfer_bench/compile/builders/cuda_builder.py (2)

3-3: LGTM: Enhanced imports and logging. The additions of `logging`, `Optional` typing, and the centralized `is_cuda_available()` utility improve code organization and maintainability. Also applies to: 10-10, 20-20, 24-24.

27-27: LGTM: Improved error visibility. Adding explicit `Optional` typing and logging warnings when resource discovery fails provides better debugging visibility compared to silent failures. Also applies to: 61-65.

tests/compile/test_tvm_ffi_builder.py (1)

1-277: LGTM: Comprehensive test coverage. The test suite thoroughly exercises TVMFFIBuilder functionality including CPU/CUDA builds, caching behavior, and error paths. The test structure follows pytest conventions appropriately.

flashinfer_bench/compile/builders/tvm_ffi_builder.py (4)

53-72: LGTM: Robust source file handling. The `_write_sources` method correctly writes source files, classifies them by extension, and includes defensive checks for edge cases (directories, no sources).

106-143: LGTM: Comprehensive build workflow. The `_build` method implements a robust workflow with proper error handling, caching, and metadata tracking. The build-then-load pattern is appropriate for TVM-FFI.

9-9: No changes needed - dependency and error handling verified. The codebase already addresses both concerns:

- `apache-tvm-ffi>=0.1.3` is listed in pyproject.toml line 31
- Build failures are handled gracefully—`tvm_ffi.cpp.build()` and `tvm_ffi.load_module()` calls (lines 115-125, 128-131) are wrapped in try-except blocks that raise descriptive `BuildError` messages

The module-level import of `tvm_ffi` is acceptable for a declared dependency; failures are caught and handled appropriately during build execution.

86-104: No issues found—arg_order is correct for destination-passing style. The code is architecturally sound. TVMFFIRunnable uses destination-passing style: it pre-allocates output tensors in `__call__` (line 68 in runnable.py), then passes both inputs and outputs to the kernel via `call_dest`. The arg_order on line 96—combining inputs and outputs—correctly reflects this calling convention.

Axes are not kernel parameters; they are symbolic dimension names extracted from tensor shapes at runtime. Including them in arg_order would be incorrect.
```python
class TVMFFIBuilder(Builder):
    """Builder using TVM-FFI with automatic caching and multi-process sharing.

    Build strategy:
    1. Check if .so exists in cache (multi-process safe)
    2. If not, compile with tvm_ffi.cpp.build_inline() to cache
    3. Load with tvm_ffi.load_module()

    Benefits:
    - Multi-process benchmark: Only first process compiles, others load from cache
    - Cross-framework: Same .so works with PyTorch, JAX, CuPy (DLPack)
    - No JIT/AOT distinction: Smart caching handles both cases
    """
```
Update docstring to match actual API usage.
The docstring (line 27) mentions tvm_ffi.cpp.build_inline(), but the implementation uses tvm_ffi.cpp.build() (line 117). This inconsistency can confuse maintainers.
Update the docstring:
```diff
 Build strategy:
 1. Check if .so exists in cache (multi-process safe)
-2. If not, compile with tvm_ffi.cpp.build_inline() to cache
+2. If not, compile with tvm_ffi.cpp.build() to cache
 3. Load with tvm_ffi.load_module()
```
3. Load with tvm_ffi.load_module()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
class TVMFFIBuilder(Builder):
    """Builder using TVM-FFI with automatic caching and multi-process sharing.

    Build strategy:
    1. Check if .so exists in cache (multi-process safe)
    2. If not, compile with tvm_ffi.cpp.build() to cache
    3. Load with tvm_ffi.load_module()

    Benefits:
    - Multi-process benchmark: Only first process compiles, others load from cache
    - Cross-framework: Same .so works with PyTorch, JAX, CuPy (DLPack)
    - No JIT/AOT distinction: Smart caching handles both cases
    """
```
🤖 Prompt for AI Agents
In flashinfer_bench/compile/builders/tvm_ffi_builder.py around lines 22 to 34,
the docstring incorrectly references tvm_ffi.cpp.build_inline() while the
implementation uses tvm_ffi.cpp.build(); update the docstring to mention
tvm_ffi.cpp.build() (or a neutral description like "tvm_ffi.cpp.build()" or
"tvm_ffi.cpp.build() / build API") so it matches the actual API used, keeping
the rest of the build strategy and benefits text intact.
```toml
dependencies = [
    "torch==2.8.0",
    "pydantic>=2.0.0",
    "safetensors>=0.5.0",
    "flashinfer-python>=0.3.1",
    "apache-tvm-ffi>=0.1.3",
]
```
🧩 Analysis chain
TVM-FFI dependency and version range
Adding apache-tvm-ffi>=0.1.3 as a hard project dependency aligns with the new TVM-FFI builder, but it also means TVM-FFI is always installed even for users who never hit that path. If TVM-FFI’s API/ABI isn’t very stable yet, you may want to consider a narrower range (e.g., <0.2 or an exact pin) or making it an extra instead of a mandatory dependency.
🌐 Web query:
Check apache-tvm-ffi’s documented compatibility guarantees and recommended version pinning strategy for downstream projects.
💡 Result:
Short summary of documented guarantees and recommended pinning:
- ABI/versioning policy: Releases are in an RFC stage using 0.X.Y where a change in X is a C ABI–breaking change and Y is non-ABI changes. After the RFC period they will adopt normal Semantic Versioning (major.minor.patch). [1]
- C ABI guarantees: The package exposes a stable C ABI surface (DLTensor, call conventions, etc.), but the distributed library is implemented in C++ and therefore depends on glibc/C++ runtime versions — binaries must be produced for manylinux/old-glibc to be broadly compatible. The docs explicitly instruct building on old-glibc (manylinux images) so consumers can run on newer glibc (forward-compatible). They also show how to verify GLIBC_ deps. [2]
- Packaging guidance for downstreams: Use the official manylinux wheels from apache-tvm-ffi when possible; if you distribute native kernels/libs that link to libtvm_ffi.so, build your artifacts in a manylinux/older-glibc container and test across target distributions. The docs include concrete commands and CI/container recommendations. [2]
Recommended pinning strategy (based on their documented policy):
- Pin to an exact released artifact (exact 0.X.Y wheel) in downstream projects (e.g., apache-tvm-ffi==0.1.1) for reproducibility and to avoid unexpected ABI changes. During the RFC stage treat X bumps as ABI-breaking — do not float X (major) or X.Y ranges that could cross an X bump. After the project moves to semantic versioning, you may follow conservative range pinning (e.g., caret/compatible ranges) that avoid major bumps. [1][2]
Sources:
- Project & PyPI release policy / RFC-stage note. [1]
- TVM-FFI docs (installation, ABI overview, and C++ packaging / glibc compatibility guidance). [2]
Pin apache-tvm-ffi to exact version during RFC stage
The apache-tvm-ffi>=0.1.3 constraint is too loose. The project is currently in RFC-stage versioning where X bumps (e.g., 0.2.0) are C ABI-breaking changes. The upstream documentation explicitly recommends pinning to an exact released artifact (e.g., apache-tvm-ffi==0.1.3) during the RFC stage to avoid unexpected ABI changes. Use apache-tvm-ffi==0.1.3 instead of >=0.1.3 to maintain reproducibility and prevent ABI-incompatible updates.
🤖 Prompt for AI Agents
In pyproject.toml around lines 26 to 32, the apache-tvm-ffi dependency is
currently specified as a loose minimum constraint (>=0.1.3); change it to an
exact pin to the released artifact (apache-tvm-ffi==0.1.3) so the project uses a
reproducible, ABI-stable version during RFC-stage; update the dependency list
accordingly and run your dependency lock/install step to verify the new
constraint is applied.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/compile/test_tvm_ffi_builder.py (1)
94-94: Remove redundant import. `torch` is already imported at line 7. This redundant import should be removed.

Apply this diff:

```diff
-import torch
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/compile/test_tvm_ffi_builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/compile/test_tvm_ffi_builder.py (4)
- flashinfer_bench/compile/builder.py (3): BuildError (48-49), build (78-85), can_build (59-61)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (2): TVMFFIBuilder (22-143), can_build (41-42)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/compile/runnable.py (1): call_dest (75-77)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
🔇 Additional comments (7)
tests/compile/test_tvm_ffi_builder.py (7)
1-408: Past review comments are no longer applicable. The two past review comments reference code that either doesn't exist in this file or has been corrected:

- The comment about `build_inline` vs `load_inline` doesn't apply to this test file - the builder implementation uses `tvm_ffi.cpp.build()` and `tvm_ffi.load_module()`.
- The comment about debug code at lines 278-279 is outdated - those lines now contain the `test_call_dest_cpu()` function definition.

18-84: LGTM! Comprehensive CPU kernel test. The test properly exercises TVMFFIBuilder with a CPU kernel using destination-passing style, verifies correctness with PyTorch tensors, and demonstrates the expected workflow.

92-168: LGTM! CUDA kernel test is well-structured. The test appropriately skips in CI environments and properly exercises CUDA kernel compilation and execution with GPU tensors.

175-211: LGTM! Good coverage of can_build logic. Both tests correctly verify that TVMFFIBuilder accepts CUDA language solutions and rejects non-CUDA solutions.

213-275: LGTM! Effective caching test. The test properly verifies that the builder caches compiled modules and reuses them on subsequent builds. The timing measurements provide useful diagnostic information without making the test brittle.

277-333: LGTM! Important test for destination-passing API. This test properly exercises the `call_dest()` method with pre-allocated output tensors, which is a key feature of the TVMFFIRunnable interface.

335-404: LGTM! Good error handling coverage. Both tests properly verify that BuildError is raised for invalid configurations (missing entry point symbol and invalid file extensions), ensuring robust error handling.
Force-push: f6680fd to 12f10eb
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer_bench/compile/builders/tvm_ffi_builder.py (2)
22-34: Fix docstring inconsistency with actual implementation. The docstring on line 27 mentions `tvm_ffi.cpp.build_inline()`, but the implementation on line 117 actually uses `tvm_ffi.cpp.build()`. This inconsistency can confuse maintainers.

Update the docstring:

```diff
 Build strategy:
 1. Check if .so exists in cache (multi-process safe)
-2. If not, compile with tvm_ffi.cpp.build_inline() to cache
+2. If not, compile with tvm_ffi.cpp.build() to cache
 3. Load with tvm_ffi.load_module()
```

115-123: Remove or complete the incomplete comment. Line 116 contains an incomplete comment: "Use build_inline instead of build to". This should either be completed with the rationale or removed to avoid confusion.

Apply this fix:

```diff
 try:
-    # Use build_inline instead of build to
+    # Build the TVM-FFI module with caching
     output_lib_path = tvm_ffi.cpp.build(
```
🧹 Nitpick comments (4)
flashinfer_bench/data/definition.py (1)
114-115: Consider adding validation for non-overlapping output/input names. The docstring now states that output names must not overlap with input names, but there's no corresponding validation in the model validators to enforce this constraint. If this is a hard requirement, consider adding a validator to prevent violations at definition-creation time.

Add a validator if this constraint should be enforced:

```diff
     @model_validator(mode="after")
     def _validate_tensor_axis_references(self) -> "Definition":
         """Validate that tensor shapes reference defined axes.

         Ensures that all axis names used in input and output tensor shapes
         are properly defined in the axes dictionary.

         Raises
         ------
         ValueError
             If any tensor shape references an undefined axis.
         """
         all_tensors = {**self.inputs, **self.outputs}
         for tensor_name, tensor_spec in all_tensors.items():
             if tensor_spec.shape is not None:
                 for axis_name in tensor_spec.shape:
                     if axis_name not in self.axes:
                         tensor_type = "input" if tensor_name in self.inputs else "output"
                         raise ValueError(
                             f'{tensor_type.capitalize()} "{tensor_name}" references undefined '
                             f'axis "{axis_name}"'
                         )
         return self
+
+    @model_validator(mode="after")
+    def _validate_no_overlapping_names(self) -> "Definition":
+        """Validate that output names do not overlap with input names."""
+        overlapping = set(self.inputs.keys()) & set(self.outputs.keys())
+        if overlapping:
+            raise ValueError(
+                f"Output names must not overlap with input names. Found: {overlapping}"
+            )
+        return self
```

flashinfer_bench/compile/builders/__init__.py (1)
1-6: Consider sorting `__all__` alphabetically. The static analysis tool suggests sorting `__all__` for consistency. While not critical, alphabetical ordering improves maintainability.

Apply this diff:

```diff
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
+__all__ = ["CUDABuilder", "PythonBuilder", "TVMFFIBuilder", "TritonBuilder"]
```

flashinfer_bench/compile/builders/tvm_ffi_builder.py (2)
41-42: Consider adding TVM-FFI availability check. Similar to `CUDABuilder.can_build()` which checks CUDA availability, `TVMFFIBuilder.can_build()` should verify that TVM-FFI is available in the environment. Currently, it only checks the language, which means compilation failures will occur at build time rather than being caught earlier by `can_build()`.

Consider adding an availability check pattern similar to CUDABuilder:

```diff
 class TVMFFIBuilder(Builder):
+    _tvm_ffi_available: bool = None
+
+    @classmethod
+    def _get_tvm_ffi_available(cls) -> bool:
+        if cls._tvm_ffi_available is None:
+            try:
+                import tvm_ffi
+                import tvm_ffi.cpp
+                cls._tvm_ffi_available = True
+            except ImportError:
+                cls._tvm_ffi_available = False
+        return cls._tvm_ffi_available
+
     def can_build(self, sol: Solution) -> bool:
-        return sol.spec.language == SupportedLanguages.CUDA
+        return sol.spec.language == SupportedLanguages.CUDA and self._get_tvm_ffi_available()
```
74-75: Address unused parameter flagged by static analysis. The static analysis tool flags `cpp_files` as unused in `_get_language()`. While the current logic (`return "cuda" if cuda_files else "cpp"`) is correct, the parameter is indeed unused. Consider whether this is intentional for API consistency or if the logic should be updated.

If the parameter should remain for API consistency, add a comment or underscore prefix. Otherwise, the logic could be:

```diff
-def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str:
-    return "cuda" if len(cuda_files) > 0 else "cpp"
+def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str:
+    if len(cuda_files) > 0:
+        return "cuda"
+    elif len(cpp_files) > 0:
+        return "cpp"
+    else:
+        # Should not reach here due to validation in _write_sources
+        return "cpp"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- flashinfer_bench/compile/builder.py (1 hunks)
- flashinfer_bench/compile/builders/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/cuda_builder.py (6 hunks)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1 hunks)
- flashinfer_bench/compile/registry.py (1 hunks)
- flashinfer_bench/compile/runnable.py (2 hunks)
- flashinfer_bench/data/definition.py (5 hunks)
- flashinfer_bench/data/trace.py (1 hunks)
- flashinfer_bench/utils.py (1 hunks)
- pyproject.toml (1 hunks)
- tests/compile/test_tvm_ffi_builder.py (1 hunks)
- tests/data/test_definition.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer_bench/data/trace.py
- flashinfer_bench/compile/builder.py
- tests/compile/test_tvm_ffi_builder.py
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer_bench/compile/builders/cuda_builder.py (3)
- flashinfer_bench/utils.py (1): is_cuda_available (57-59)
- flashinfer_bench/compile/builder.py (1): BuildError (48-49)
- flashinfer_bench/compile/runnable.py (1): Runnable (9-38)

flashinfer_bench/compile/builders/__init__.py (1)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (22-143)

tests/data/test_definition.py (1)
- flashinfer_bench/data/definition.py (2): const_axes (194-202), var_axes (205-213)

flashinfer_bench/compile/registry.py (1)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (22-143)

flashinfer_bench/compile/builders/tvm_ffi_builder.py (4)
- flashinfer_bench/compile/builder.py (8): Builder (52-95), BuildError (48-49), create_pkg_name (33-45), can_build (59-61), _make_key (74-76), _make_closer (69-71), _build (64-66), build (78-85)
- flashinfer_bench/compile/runnable.py (2): Runnable (9-38), TVMFFIRunnable (41-82)
- flashinfer_bench/data/solution.py (1): SupportedLanguages (12-24)
- flashinfer_bench/env.py (1): get_fib_cache_path (46-57)

flashinfer_bench/compile/runnable.py (2)
- flashinfer_bench/data/definition.py (2): get_var_values (239-279), get_output_shapes (343-363)
- flashinfer_bench/utils.py (1): dtype_str_to_torch_dtype (39-45)
🪛 Ruff (0.14.4)
flashinfer_bench/compile/builders/__init__.py
6-6: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer_bench/compile/builders/tvm_ffi_builder.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
71-71: Avoid specifying long messages outside the exception class
(TRY003)
74-74: Unused method argument: cpp_files
(ARG002)
81-83: Avoid specifying long messages outside the exception class
(TRY003)
93-93: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
131-131: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/data/definition.py
267-270: Avoid specifying long messages outside the exception class
(TRY003)
275-278: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
🔇 Additional comments (11)
flashinfer_bench/data/definition.py (3)
193-213: LGTM! API migration to properties is more Pythonic.

The migration from `get_const_axes()` and `get_var_axes()` methods to `const_axes` and `var_axes` properties is a good improvement. Using properties for simple accessors aligns with Python conventions, and the `@cached_property` decorator ensures efficient computation.
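For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of a `@cached_property` accessor. The property names mirror this review, but the storage field and filtering logic are illustrative assumptions, not the actual `Definition` code:

```python
from functools import cached_property
from typing import Dict


class AxesExample:
    """Illustrative only; the real Definition class differs."""

    def __init__(self, axes: Dict[str, object]) -> None:
        self._axes = axes  # hypothetical storage

    @cached_property
    def const_axes(self) -> Dict[str, object]:
        # Computed once on first access, then cached on the instance.
        return {k: v for k, v in self._axes.items() if isinstance(v, int)}

    @cached_property
    def var_axes(self) -> Dict[str, object]:
        return {k: v for k, v in self._axes.items() if not isinstance(v, int)}
```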
215-237: Consistent API rename.

The rename from `get_var_axes_bindings` to `var_axes_bindings` maintains consistency with the other property renames in this file.
239-279: Well-designed validation method.

The new `get_var_values` method provides comprehensive validation:
- Consistency checking when axes appear in multiple inputs
- Completeness checking to ensure all variable axes have values
- Clear error messages for debugging
The implementation is solid, and the error messages are appropriately detailed despite the static-analysis hints. A self-contained sketch of the validation idea follows below.
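To make the two validation passes concrete, here is a small, runnable sketch of the same idea — infer variable-axis values from input shapes while checking consistency and completeness. The function name and data layout are illustrative assumptions, not the actual `Definition.get_var_values`:

```python
from typing import Dict, List, Tuple


def infer_var_values(
    axis_names: Dict[str, List[str]],   # input name -> axis name per dimension
    shapes: Dict[str, Tuple[int, ...]], # input name -> runtime shape
    var_axes: List[str],                # axes that must receive a value
) -> Dict[str, int]:
    values: Dict[str, int] = {}
    for inp, names in axis_names.items():
        for axis, size in zip(names, shapes[inp]):
            if axis in values and values[axis] != size:
                # Consistency: an axis seen in several inputs must agree.
                raise ValueError(f"axis {axis!r} inconsistent: {values[axis]} vs {size}")
            values[axis] = size
    missing = [a for a in var_axes if a not in values]
    if missing:
        # Completeness: every variable axis needs a value.
        raise ValueError(f"unbound variable axes: {missing}")
    return {a: values[a] for a in var_axes}


print(infer_var_values({"q": ["seq_len", "d"]}, {"q": (8, 128)}, ["seq_len"]))
# -> {'seq_len': 8}
```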
tests/data/test_definition.py (1)
43-44: Tests correctly updated for new property API.

The test assertions have been properly updated to use the new property-based API (`const_axes` and `var_axes` instead of method calls).

flashinfer_bench/compile/registry.py (1)
55-63: Registry ordering is clear; previous fallback concern was addressed.

The builder priority is now explicit: Python > Triton > TVM-FFI > CUDA (pybind11). The past review comment regarding fallback semantics was marked as addressed in commit f6680fd.
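A hypothetical sketch of that first-match dispatch; the registry internals and builder instances here are assumptions for illustration, with only `TVMFFIBuilder` confirmed by this PR:

```python
def build_with_priority(builders, solution):
    # First match wins: list order encodes priority
    # (Python > Triton > TVM-FFI > CUDA/pybind11 in this PR).
    for builder in builders:
        if builder.can_build(solution):
            return builder.build(solution)
    raise RuntimeError("no registered builder can build this solution")
```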
flashinfer_bench/utils.py (1)
57-60: Good addition for centralized CUDA availability checking.

This utility function provides a centralized way to check CUDA availability, which is now used by the CUDA builder. This improves consistency across the codebase.
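A plausible shape for such a helper, assuming it delegates to PyTorch — the actual three-line implementation in `flashinfer_bench/utils.py` may differ:

```python
import torch


def is_cuda_available() -> bool:
    # Single source of truth for CUDA availability across builders.
    return torch.cuda.is_available()
```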
flashinfer_bench/compile/builders/cuda_builder.py (1)
3-3: Excellent improvements to logging and error handling.

The changes improve the CUDA builder's robustness:

- Centralized CUDA availability check via `is_cuda_available()`
- Enhanced logging for resource discovery failures
- Better type hints with `Optional[List[str]]`
- Clearer error messages
These changes make debugging easier and align with the broader TVM-FFI integration.
Also applies to: 10-10, 20-20, 24-24, 27-27, 61-66, 125-125
flashinfer_bench/compile/runnable.py (2)
11-11: Breaking change: meta parameter now required.

The `meta` parameter is no longer optional in `Runnable.__init__`. This is a breaking change but appears intentional since all builders in the codebase provide metadata. The change enforces that metadata must always be present.

Also applies to: 19-19
41-82: TVMFFIRunnable implementation is well-designed.

The destination-passing execution model is correctly implemented:
- Pre-allocates output tensors using Definition metadata
- Infers variable axis values from input shapes
- Handles device placement appropriately (defaults to "cpu" if no inputs)
- Returns single tensor or list based on output count
The implementation leverages the new `get_var_values` and `get_output_shapes` methods from Definition, demonstrating good integration with the updated API (see the sketch below).
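A condensed, hypothetical sketch of that call path. The device/shape lines echo fragments quoted later in this review; `call_dest`, the positional argument passing, and the output dtypes are simplifications, not the actual `TVMFFIRunnable` code:

```python
import torch


def call_dest(fn, defn, **kwargs):
    # Infer variable axes from input shapes, then resolve output shapes.
    var_values = defn.get_var_values({k: tuple(v.shape) for k, v in kwargs.items()})
    output_shapes = defn.get_output_shapes(var_values)
    # Default to CPU when there are no inputs to borrow a device from.
    device = next(iter(kwargs.values())).device if kwargs else "cpu"
    outputs = [torch.empty(s, device=device) for s in output_shapes.values()]
    fn(*kwargs.values(), *outputs)  # kernel writes into pre-allocated outputs
    return outputs[0] if len(outputs) == 1 else outputs
```

flashinfer_bench/compile/builders/tvm_ffi_builder.py (2)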
96-96: Confirm arg_order correctly includes outputs.

The `arg_order` on line 96 correctly includes both inputs and outputs: `list(defn.inputs.keys()) + list(defn.outputs.keys())`. This aligns with the destination-passing style where outputs are pre-allocated and passed to the function. The past review concern has been properly addressed.
53-72: Core build logic is well-structured.

The TVMFFIBuilder implementation follows good practices:
- Clear separation of concerns (source writing, compilation, module loading)
- Comprehensive error handling with BuildError exceptions
- Proper caching via build_path
- Correct integration with TVMFFIRunnable for destination-passing execution
The overall implementation is solid and ready for integration; a rough sketch of the build-then-load flow follows below.
Also applies to: 86-104, 106-143
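As a rough sketch of that flow — only the names `tvm_ffi.cpp.build` and `tvm_ffi.load_module` are confirmed by this review; their argument lists are left as placeholders because the real signatures are not shown here:

```python
import tvm_ffi
import tvm_ffi.cpp

# Hypothetical invocation: compile the cached sources into a shared library,
# then load it as a module for TVMFFIRunnable to call.
lib_path = tvm_ffi.cpp.build(...)   # sources, output dir, flags: assumed args
module = tvm_ffi.load_module(lib_path)
```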
/gemini review
Code Review
This pull request introduces TVMFFIBuilder to integrate TVM-FFI for building CUDA/C++ kernels, which is a significant feature enhancement. The implementation includes caching to improve build times and refactors several parts of the codebase for better quality, such as using cached_property in the Definition class and centralizing CUDA availability checks. The new tests for TVMFFIBuilder are quite comprehensive. I've identified a few areas for improvement, primarily concerning robustness and clarity in the new builder and runnable classes. Specifically, there's a potential issue with handling source files in subdirectories, some fragility in device handling within TVMFFIRunnable, and a minor regression in the close method's robustness. Additionally, there are some inconsistencies in comments and docstrings regarding the TVM-FFI build function being used.
```python
if src_path.is_dir():
    raise BuildError(f"Source path is a directory: {src_path}")

src_path.write_text(src.content)
```
This line will raise a `FileNotFoundError` if `src.path` contains subdirectories (e.g., `a/b/kernel.cu`) because the parent directories are not created. Please ensure parent directories exist before writing the file. For example, you could add `src_path.parent.mkdir(parents=True, exist_ok=True)` before this line.
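A minimal illustration of the suggested guard, using the variable names from the snippet above:

```python
# Create any missing parent directories before writing the source file,
# so nested paths like "a/b/kernel.cu" do not raise FileNotFoundError.
src_path.parent.mkdir(parents=True, exist_ok=True)
src_path.write_text(src.content)
```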
```python
)
output_shapes = self._definition.get_output_shapes(var_values)
output_tensors: Dict[str, torch.Tensor] = {}
device = next(iter(kwargs.values())).device if len(kwargs) > 0 else "cpu"
```
The logic to determine the device is a bit fragile. It assumes the first value in `kwargs` is a tensor and that all input tensors are on the same device. This can lead to errors if `kwargs` contains non-tensor values or tensors on different devices. A more robust approach would be to iterate through all tensor inputs, verify they are on the same device, and then use that device. For example:

```python
devices = {v.device for v in kwargs.values() if hasattr(v, "device")}
if len(devices) > 1:
    raise ValueError("All input tensors must be on the same device")
device = devices.pop() if devices else "cpu"
```
```
2. If not, compile with tvm_ffi.cpp.build_inline() to cache
3. Load with tvm_ffi.load_module()
```
The docstring states that `tvm_ffi.cpp.build_inline()` is used for compilation. However, the implementation in `_build` (line 117) uses `tvm_ffi.cpp.build()`. The comment on line 116 is also misleading. Please update the docstring and the comment to be consistent with the implementation to avoid confusion.
```python
def close(self) -> None:
    if self._closer:
        self._closer()
        self._closer = None
```
The `close` method in the base `Runnable` class uses a `try...finally` block to ensure `self._closer` is set to `None` even if an exception occurs. This implementation in `TVMFFIRunnable` is missing the `try...finally`, which is a slight regression in robustness. If `self._closer()` raises an exception, `self._closer` won't be set to `None`, and it might be called again later, which is not idempotent. Please restore the `try...finally` block for robustness, like so:
```python
def close(self) -> None:
    if self._closer:
        try:
            self._closer()
        finally:
            self._closer = None
```
This PR adds the `TVMFFIBuilder`, beginning the integration of TVM-FFI.
Summary by CodeRabbit
New Features
Tests
Refactor
Chores